{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd1badee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Auto-Sklearn cannot be imported.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x12613c510>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "import numpy as np\n",
    "import pyepo\n",
    "from pyepo.model.grb import optGrbModel\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "import time as time\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR\n",
    "from ModelsKnapSack import KnapSackNet, ValPredictNet\n",
    "from generate_knapsack_data import Gen_Knapsack_data\n",
    "from knapsack_utils import RegretLoss\n",
    "from Trainer import trainer\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Set the random seed\n",
    "seed = 42\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e8b2b6e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating training data for knapsack problem with 2 knapsacks and 10 items\n",
      "Restricted license - for non-production use only - expires 2024-10-28\n",
      "Optimizing for optDataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 800/800 [00:00<00:00, 1511.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizing for optDataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1011.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizing for optDataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 972.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished building dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Data parameters\n",
    "num_knapsack = 2\n",
    "num_item = 10\n",
    "num_feat = 10\n",
    "num_data = 1000\n",
    "# Load Data\n",
    "state = Gen_Knapsack_data(num_data, num_feat, num_item, num_knapsack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dc29dca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "weights_numpy = state[\"weights_numpy\"]\n",
    "capacities = state[\"capacities\"]\n",
    "dataset_train = state[\"dataset_train\"]\n",
    "dataset_test = state[\"dataset_test\"]\n",
    "contexts_numpy = state[\"contexts_numpy\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "81ff9290",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/danielmckenzie/My-Drive/Research/Fixed_Point_Networks/Diff-Opt-Over-Polytopes-Project/SPO_with_DYS/source/Knapsack/dYS_opt_net.py:29: UserWarning: The operator 'aten::linalg_svd' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)\n",
      "  U, s, VT = torch.linalg.svd(self.A, full_matrices=False)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "KnapSackNet(\n",
       "  (fc_1): Linear(in_features=10, out_features=1000, bias=True)\n",
       "  (fc_2): Linear(in_features=1000, out_features=1000, bias=True)\n",
       "  (fc_3): Linear(in_features=1000, out_features=10, bias=True)\n",
       "  (leaky_relu): LeakyReLU(negative_slope=0.1)\n",
       "  (dropout): Dropout(p=0.3, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initialize DYS_net\n",
    "device = 'mps'\n",
    "weights = torch.Tensor(weights_numpy).to(device)\n",
    "capacities_DYS = capacities.to(device)\n",
    "DYS_net = KnapSackNet(weights, capacities_DYS, num_knapsack, num_item, num_feat, device=device)\n",
    "DYS_net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5bb22a24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial validation loss is  0.27018412947654724\n",
      "epoch:  1 validation loss is  0.1253541111946106 epoch time:  1.4141261577606201\n",
      "epoch:  2 validation loss is  0.1253541111946106 epoch time:  1.9979557991027832\n",
      "epoch:  3 validation loss is  0.1253541111946106 epoch time:  2.3797860145568848\n",
      "epoch:  4 validation loss is  0.1253541111946106 epoch time:  2.3984780311584473\n",
      "epoch:  5 validation loss is  0.1253541111946106 epoch time:  2.3953583240509033\n",
      "epoch:  6 validation loss is  0.1253541111946106 epoch time:  2.404285192489624\n",
      "epoch:  7 validation loss is  0.1253541111946106 epoch time:  2.3834598064422607\n",
      "epoch:  8 validation loss is  0.1253541111946106 epoch time:  2.322476863861084\n",
      "epoch:  9 validation loss is  0.1253541111946106 epoch time:  2.374821901321411\n",
      "epoch:  10 validation loss is  0.1253541111946106 epoch time:  2.3892669677734375\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m max_epochs \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m20\u001b[39m\n\u001b[1;32m      3\u001b[0m learning_rate \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1e-3\u001b[39m\n\u001b[0;32m----> 4\u001b[0m val_loss_hist, epoch_time_hist, best_test_loss, time_till_best_val_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mDYS_net\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_item\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m      5\u001b[0m \u001b[43m                         \u001b[49m\u001b[43mnum_knapsack\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_type\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mDYS\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m      6\u001b[0m \u001b[43m                         \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/My-Drive/Research/Fixed_Point_Networks/Diff-Opt-Over-Polytopes-Project/SPO_with_DYS/source/Knapsack/Trainer.py:136\u001b[0m, in \u001b[0;36mtrainer\u001b[0;34m(net, train_dataset, test_dataset, val_dataset, num_item, num_knapsack, max_epochs, learning_rate, model_type, device)\u001b[0m\n\u001b[1;32m    133\u001b[0m epoch_time_hist\u001b[38;5;241m.\u001b[39mappend(epoch_time)\n\u001b[1;32m    135\u001b[0m \u001b[38;5;66;03m## Now compute loss on validation set\u001b[39;00m\n\u001b[0;32m--> 136\u001b[0m val_loss \u001b[38;5;241m=\u001b[39m \u001b[43mCompute_Test_Loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnet\u001b[49m\u001b[43m,\u001b[49m\u001b[43mloader_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_knapsack\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_item\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    138\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m val_loss \u001b[38;5;241m<\u001b[39m best_val_loss:\n\u001b[1;32m    139\u001b[0m     state_save_name \u001b[38;5;241m=\u001b[39m checkpt_path\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbest.pth\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
      "File \u001b[0;32m~/My-Drive/Research/Fixed_Point_Networks/Diff-Opt-Over-Polytopes-Project/SPO_with_DYS/source/Knapsack/knapsack_utils.py:34\u001b[0m, in \u001b[0;36mCompute_Test_Loss\u001b[0;34m(net, loader, model_type, metric, num_knapsack, num_item, device)\u001b[0m\n\u001b[1;32m     32\u001b[0m opt_sol \u001b[38;5;241m=\u001b[39m opt_sol\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m     33\u001b[0m opt_value \u001b[38;5;241m=\u001b[39m opt_value\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 34\u001b[0m predicted \u001b[38;5;241m=\u001b[39m \u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43md_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m model_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDYS\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m     36\u001b[0m     loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m metric(w_batch, predicted[:,:\u001b[38;5;241m-\u001b[39m(num_knapsack \u001b[38;5;241m+\u001b[39m num_item)], opt_sol, opt_value, eval_mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mitem()\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/My-Drive/Research/Fixed_Point_Networks/Diff-Opt-Over-Polytopes-Project/SPO_with_DYS/source/Knapsack/dYS_opt_net.py:117\u001b[0m, in \u001b[0;36mDYS_opt_net.forward\u001b[0;34m(self, d, eps, max_depth, depth_warning)\u001b[0m\n\u001b[1;32m    114\u001b[0m       \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mproject_C1(z)\u001b[38;5;241m.\u001b[39mdetach()\n\u001b[1;32m    116\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m    \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtest_time_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/My-Drive/Research/Fixed_Point_Networks/Diff-Opt-Over-Polytopes-Project/SPO_with_DYS/source/Knapsack/ModelsKnapSack.py:82\u001b[0m, in \u001b[0;36mKnapSackNet.test_time_forward\u001b[0;34m(self, d)\u001b[0m\n\u001b[1;32m     80\u001b[0m solutions \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros(w\u001b[38;5;241m.\u001b[39mshape, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     81\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(batch_size):\n\u001b[0;32m---> 82\u001b[0m   \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mknapsack_solver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msetObj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mw\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_resources\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     83\u001b[0m   solution, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mknapsack_solver\u001b[38;5;241m.\u001b[39msolve()\n\u001b[1;32m     84\u001b[0m   zero_padding \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mzero_padding_dim), device\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyepo/model/grb/grbmodel.py:46\u001b[0m, in \u001b[0;36moptGrbModel.setObj\u001b[0;34m(self, c)\u001b[0m\n\u001b[1;32m     44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(c) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_cost:\n\u001b[1;32m     45\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSize of cost vector cannot match vars.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 46\u001b[0m obj \u001b[38;5;241m=\u001b[39m \u001b[43mgp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mquicksum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mc\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mk\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     47\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model\u001b[38;5;241m.\u001b[39msetObjective(obj)\n",
      "File \u001b[0;32msrc/gurobipy/gurobi.pxi:3706\u001b[0m, in \u001b[0;36mgurobipy.quicksum\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyepo/model/grb/grbmodel.py:46\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(c) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_cost:\n\u001b[1;32m     45\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSize of cost vector cannot match vars.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 46\u001b[0m obj \u001b[38;5;241m=\u001b[39m gp\u001b[38;5;241m.\u001b[39mquicksum(c[i] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx[k] \u001b[38;5;28;01mfor\u001b[39;00m i, k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx))\n\u001b[1;32m     47\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model\u001b[38;5;241m.\u001b[39msetObjective(obj)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train\n",
    "max_epochs = 20\n",
    "learning_rate = 1e-3\n",
    "val_loss_hist, epoch_time_hist, best_test_loss, time_till_best_val_loss = trainer(DYS_net, dataset_train, dataset_test, dataset_test, num_item,\n",
    "                         num_knapsack, max_epochs, learning_rate, model_type = \"DYS\",\n",
    "                         device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d40b6d0",
   "metadata": {},
   "source": [
    "# Preparing the benchmark approaches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a26c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed)\n",
    "val_net = ValPredictNet(num_knapsack, num_item, num_feat, weights,\n",
    "                        capacities, device=device)\n",
    "val_net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aca9550",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss_hist_2stage, epoch_time_hist_2stage, best_test_loss_2stage = trainer(val_net, dataset_train, dataset_test, num_item,\n",
    "                         num_knapsack, max_epochs, learning_rate, \n",
    "                         model_type = \"Two-stage\", device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc7efc7f",
   "metadata": {},
   "source": [
    "## Implementing the SPO+ Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab3eb54",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed)\n",
    "val_net2 = ValPredictNet(num_knapsack, num_item, num_feat, weights,\n",
    "                        capacities, device=device)\n",
    "val_net2.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9dd13be",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed)\n",
    "test_loss_hist_SPO, epoch_time_hist_SPO, best_test_loss_SPO = trainer(val_net2, dataset_train, dataset_test, num_item,\n",
    "                         num_knapsack, max_epochs, learning_rate, \n",
    "                         model_type = \"SPO+\", device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2f1aa83",
   "metadata": {},
   "source": [
    "## Differentiable Blackbox Optimization Approach\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ae0b88",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed)\n",
    "val_net3 = ValPredictNet(num_knapsack, num_item, num_feat, weights,\n",
    "                        capacities, device=device)\n",
    "val_net3.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2fe97dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss_hist_BBOpt, epoch_time_hist_BBOpt, best_test_loss_SPO = trainer(val_net3, dataset_train, dataset_test, num_item,\n",
    "                         num_knapsack, max_epochs, learning_rate, \n",
    "                         model_type = \"BBOpt\", device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c729eb4",
   "metadata": {},
   "source": [
    "## Perterbed Optimizer Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91937f2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed)\n",
    "val_net4 = ValPredictNet(num_knapsack, num_item, num_feat, weights,\n",
    "                        capacities, device=device)\n",
    "val_net4.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b201a33e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss_hist_PertOpt, epoch_time_hist_PertOpt, best_test_loss_PertOpt = trainer(val_net4, dataset_train, dataset_test, num_item,\n",
    "                         num_knapsack, max_epochs, learning_rate, \n",
    "                         model_type = \"PertOpt\", device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aaadbc9",
   "metadata": {},
   "source": [
    "## Fenchel-Young Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2c1c1cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed)\n",
    "val_net5 = ValPredictNet(num_knapsack, num_item, num_feat, weights,\n",
    "                        capacities, device=device)\n",
    "val_net5.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b10fc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss_hist_PertOpt_FY, epoch_time_hist_PertOpt_FY, best_test_loss_FY = trainer(val_net5, dataset_train, dataset_test, num_item,\n",
    "                         num_knapsack, max_epochs, learning_rate, \n",
    "                         model_type = \"PertOpt-FY\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87fc5f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(test_loss_hist_DYS, label='DYS')\n",
    "plt.plot(test_loss_hist_2stage, label='Two stage')\n",
    "plt.plot(test_loss_hist_SPO, label='SPO+')\n",
    "plt.plot(test_loss_hist_BBOpt, label=\"BBOpt\")\n",
    "plt.plot(test_loss_hist_PertOpt, label=\"PertOpt\")\n",
    "plt.plot(test_loss_hist_PertOpt_FY, label=\"PertOpt-FY\")\n",
    "plt.legend()\n",
    "plt.title(\"Regret vs epochs\")\n",
    "plt.savefig('Knapsack_regret.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc2a8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec20d599",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.cumsum(epoch_time_hist_DYS), label='DYS')\n",
    "plt.plot(np.cumsum(epoch_time_hist_2stage), label='Two stage')\n",
    "plt.plot(np.cumsum(epoch_time_hist_SPO), label='SPO+')\n",
    "plt.plot(np.cumsum(epoch_time_hist_BBOpt), label='BBOpt')\n",
    "plt.plot(np.cumsum(epoch_time_hist_PertOpt), label='PertOpt')\n",
    "plt.plot(np.cumsum(epoch_time_hist_PertOpt_FY), label='PertOpt-FY')\n",
    "plt.legend()\n",
    "plt.title(\"Cumulative Training Time vs Epochs\")\n",
    "plt.savefig('Knapsack_train_time.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75e9f6ac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
